{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "j-Iyf5gv5oBq"
   },
   "source": [
    "# Preprocessing Data with Simple Example using TensorFlow Transform"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mPt5BHTwy_0F"
   },
   "source": [
    "### Learning objectives\n",
    "\n",
    "* Create a preprocessing function.\n",
    "* Use the resulting `transform_fn` directory.\n",
    "* Export the model.\n",
    "\n",
    "In this notebook, you learn a very simple example of how <a target='_blank' href='https://www.tensorflow.org/tfx/transform/get_started/'>TensorFlow Transform</a> (<code>tf.Transform</code>) can be used to preprocess data using exactly the same code for both training a model and serving inferences in production.\n",
    "\n",
    "TensorFlow Transform is a library for preprocessing input data for TensorFlow, including creating features that require a full pass over the training dataset.  For example, using TensorFlow Transform you could:\n",
    "\n",
    "* Normalize an input value by using the mean and standard deviation.\n",
    "* Convert strings to integers by generating a vocabulary over all of the input values.\n",
    "* Convert floats to integers by assigning them to buckets, based on the observed data distribution.\n",
    "\n",
    "TensorFlow has built-in support for manipulations on a single example or a batch of examples. `tf.Transform` extends these capabilities to support full passes over the entire training dataset.\n",
    "\n",
    "The output of `tf.Transform` is exported as a TensorFlow graph which you can use for both training and serving. Using the same graph for both training and serving can prevent skew, since the same transformations are applied in both stages.\n",
    "\n",
    "Each learning objective will correspond to a __#TODO__ in the notebook, where you will complete the notebook cell's code before running the cell. Refer to the [solution notebook](../solutions/simple.ipynb) for reference."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6c8lD3uQm8m5"
   },
   "source": [
    "### Upgrade Pip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "EmiQXNLZm8z-",
    "outputId": "6dcf9f92-ec49-4dc8-97ba-b466d887f4e1"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: pip in /opt/conda/lib/python3.7/site-packages (22.1.1)\n"
     ]
    }
   ],
   "source": [
    "!pip install --upgrade pip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hiBxgnc-m8-X"
   },
   "source": [
    "### Install TensorFlow Transform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "j2CTKbMNm9I4",
    "outputId": "273dbe77-793d-4ba8-aa7c-7ed004ee6f93"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
      "tensorflow-io 0.21.0 requires tensorflow<2.7.0,>=2.6.0, but you have tensorflow 2.8.2 which is incompatible.\n",
      "tensorflow-io 0.21.0 requires tensorflow-io-gcs-filesystem==0.21.0, but you have tensorflow-io-gcs-filesystem 0.26.0 which is incompatible.\n",
      "google-cloud-storage 2.3.0 requires google-cloud-core<3.0dev,>=2.3.0, but you have google-cloud-core 1.7.2 which is incompatible.\n",
      "cloud-tpu-client 0.10 requires google-api-python-client==1.8.0, but you have google-api-python-client 1.12.11 which is incompatible.\u001b[0m\u001b[31m\n",
      "\u001b[0m"
     ]
    }
   ],
   "source": [
    "!pip install -q -U tensorflow_transform"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Restart** the kernel to use updated packages. (On the Notebook menu, select Kernel > Restart Kernel > Restart)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "R0mXLOJR_-dv",
    "outputId": "308976f5-1396-4bdf-9a9c-50b7784878d8"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<module 'pkg_resources' from '/opt/conda/lib/python3.7/site-packages/pkg_resources/__init__.py'>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# This cell is only necessary because packages were installed while python was running.\n",
    "import pkg_resources\n",
    "import importlib\n",
    "\n",
    "importlib.reload(pkg_resources)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RptgLn2RYuK3"
   },
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "K4QXVIM7iglN"
   },
   "outputs": [],
   "source": [
    "import pathlib\n",
    "import pprint\n",
    "import tempfile\n",
    "\n",
    "import tensorflow as tf\n",
    "import tensorflow_transform as tft\n",
    "\n",
    "import tensorflow_transform.beam as tft_beam\n",
    "from tensorflow_transform.tf_metadata import dataset_metadata\n",
    "from tensorflow_transform.tf_metadata import schema_utils"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CxOxaaOYRfl7"
   },
   "source": [
    "## Data: Create some dummy data\n",
    "You'll create some simple dummy data for your simple example:\n",
    "\n",
    "* `raw_data` is the initial raw data that you're going to preprocess\n",
    "* `raw_data_metadata` contains the schema that tells us the types of each of the columns in `raw_data`.  In this case, it's very simple."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "-R236Tkf_ON3"
   },
   "outputs": [],
   "source": [
    "raw_data = [\n",
    "      {'x': 1, 'y': 1, 's': 'hello'},\n",
    "      {'x': 2, 'y': 2, 's': 'world'},\n",
    "      {'x': 3, 'y': 3, 's': 'hello'}\n",
    "  ]\n",
    "\n",
    "raw_data_metadata = dataset_metadata.DatasetMetadata(\n",
    "    schema_utils.schema_from_feature_spec({\n",
    "        'y': tf.io.FixedLenFeature([], tf.float32),\n",
    "        'x': tf.io.FixedLenFeature([], tf.float32),\n",
    "        's': tf.io.FixedLenFeature([], tf.string),\n",
    "    }))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xexbWmQEBUBZ"
   },
   "source": [
    "## Transform: Create a preprocessing function\n",
    "\n",
    "The _preprocessing function_ is the most important concept of tf.Transform. A preprocessing function is where the transformation of the dataset really happens. It accepts and returns a dictionary of tensors, where a tensor means a <a target='_blank' href='https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/Tensor'><code>Tensor</code></a> or <a target='_blank' href='https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/SparseTensor'><code>SparseTensor</code></a>. There are two main groups of API calls that typically form the heart of a preprocessing function:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Zadh6MXLS3eD"
   },
   "source": [
    "1. **TensorFlow Ops:** Any function that accepts and returns tensors, which usually means TensorFlow ops. These add TensorFlow operations to the graph that transforms raw data into transformed data one feature vector at a time.  These will run for every example, during both training and serving.\n",
    "2. **Tensorflow Transform Analyzers/Mappers:** Any of the analyzers/mappers provided by tf.Transform. These also accept and return tensors, and typically contain a combination of Tensorflow ops and Beam computation, but unlike TensorFlow ops they only run in the Beam pipeline during analysis requiring a full pass over the entire training dataset. The Beam computation runs only once, (prior to training, during analysis), and typically makes a full pass over the entire training dataset. They create `tf.constant` tensors, which are added to your graph. For example, `tft.min` computes the minimum of a tensor over the training dataset.\n",
    "\n",
    "Caution: When you apply your preprocessing function to serving inferences, the constants that were created by analyzers during training do not change.  If your data has trend or seasonality components, plan accordingly.\n",
    "\n",
    "Note: The `preprocessing_fn` is not directly callable. This means that\n",
    "calling `preprocessing_fn(raw_data)` will not work. Instead, it must\n",
    "be passed to the Transform Beam API as shown in the following cells."
   ]
  },
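  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the distinction between the two groups concrete, here is a minimal pure-Python sketch of the analyze-then-transform idea. It does not use `tf.Transform`, and the helper names `analyze` and `transform` are hypothetical, chosen only for illustration. The analysis step makes one full pass over the training values to compute a constant (the mean), and the transform step applies that frozen constant to any single example, including one that is only seen at serving time.\n",
    "\n",
    "```python\n",
    "def analyze(training_values):\n",
    "    # Full pass over the training data: compute the constant once.\n",
    "    return sum(training_values) / len(training_values)\n",
    "\n",
    "\n",
    "def transform(value, mean):\n",
    "    # Per-example operation: reuses the frozen constant, never recomputes it.\n",
    "    return value - mean\n",
    "\n",
    "\n",
    "train_x = [1.0, 2.0, 3.0]\n",
    "mean_x = analyze(train_x)\n",
    "\n",
    "centered_train = [transform(v, mean_x) for v in train_x]\n",
    "\n",
    "# A serving-time value is centered with the training mean, not its own.\n",
    "centered_serving = transform(10.0, mean_x)\n",
    "\n",
    "print(centered_train, centered_serving)\n",
    "```\n",
    "\n",
    "The serving value is shifted by the training mean of 2.0, which is why the caution above about trend and seasonality matters."
   ]
  },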
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "H2wANNF_2dCR"
   },
   "outputs": [],
   "source": [
    "def preprocessing_fn(inputs):\n",
    "    \"\"\"Preprocess input columns into transformed columns.\"\"\"\n",
    "    x = inputs['x']\n",
    "    y = inputs['y']\n",
    "    s = inputs['s']\n",
    "    x_centered = x - tft.mean(x)\n",
    "    y_normalized = tft.scale_to_0_1(y)\n",
    "    s_integerized = tft.compute_and_apply_vocabulary(s)\n",
    "    x_centered_times_y_normalized = (x_centered * y_normalized)\n",
    "    return {\n",
    "        'x_centered': x_centered,\n",
    "        'y_normalized': y_normalized,\n",
    "        's_integerized': s_integerized,\n",
    "        'x_centered_times_y_normalized': x_centered_times_y_normalized,\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cSl9qyTCbBKR"
   },
   "source": [
    "## Syntax\n",
    "\n",
    "You're almost ready to put everything together and use <a target='_blank' href='https://beam.apache.org/'>Apache Beam</a> to run it.\n",
    "\n",
    "Apache Beam uses a <a target='_blank' href='https://beam.apache.org/documentation/programming-guide/#applying-transforms'>special syntax to define and invoke transforms</a>.  For example, in this line:\n",
    "\n",
    "```\n",
    "result = pass_this | 'name this step' >> to_this_call\n",
    "```\n",
    "\n",
    "The method `to_this_call` is being invoked and passed the object called `pass_this`, and <a target='_blank' href='https://stackoverflow.com/questions/50519662/what-does-the-redirection-mean-in-apache-beam-python'>this operation will be referred to as `name this step` in a stack trace</a>.  The result of the call to `to_this_call` is returned in `result`.  You will often see stages of a pipeline chained together like this:\n",
    "\n",
    "```\n",
    "result = apache_beam.Pipeline() | 'first step' >> do_this_first() | 'second step' >> do_this_last()\n",
    "```\n",
    "\n",
    "and since that started with a new pipeline, you can continue like this:\n",
    "\n",
    "```\n",
    "next_result = result | 'doing more stuff' >> another_function()\n",
    "```"
   ]
  },
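  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the `|` and `>>` operators look unusual, the following pure-Python sketch may help. It is only a rough analogy built on Python operator overloading, not the Beam API: the `Step` class and the `double` and `add_one` steps are hypothetical, and Beam's actual implementation is more involved. Because `>>` binds more tightly than `|`, each label is attached to its step first, and then the value on the left is piped into the labeled step.\n",
    "\n",
    "```python\n",
    "class Step:\n",
    "    # Hypothetical stand-in for a transform; not part of Apache Beam.\n",
    "    def __init__(self, fn, label=None):\n",
    "        self.fn = fn\n",
    "        self.label = label\n",
    "\n",
    "    def __rrshift__(self, label):\n",
    "        # Handles: 'some label' >> step\n",
    "        return Step(self.fn, label)\n",
    "\n",
    "    def __ror__(self, value):\n",
    "        # Handles: value | step\n",
    "        return self.fn(value)\n",
    "\n",
    "\n",
    "double = Step(lambda xs: [2 * x for x in xs])\n",
    "add_one = Step(lambda xs: [x + 1 for x in xs])\n",
    "\n",
    "labeled = 'double' >> double\n",
    "result = [1, 2, 3] | 'double' >> double | 'add one' >> add_one\n",
    "\n",
    "print(labeled.label, result)\n",
    "```\n",
    "\n",
    "Here the list is doubled and then incremented, mirroring how each stage of a chained pipeline receives the output of the stage before it."
   ]
  },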
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5kLDSxOQ8xgg"
   },
   "source": [
    "## Putting it all together\n",
    "Now you're ready to transform your data. You'll use Apache Beam with a direct runner, and supply three inputs:\n",
    "\n",
    "1. `raw_data` - The raw input data that you created above\n",
    "2. `raw_data_metadata` - The schema for the raw data\n",
    "3. `preprocessing_fn` - The function that you created to do your transformation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "mAF9w7RTZU7c"
   },
   "outputs": [],
   "source": [
    "def main(output_dir):\n",
    "  # Ignore the warnings\n",
    "  with tft_beam.Context(temp_dir=tempfile.mkdtemp()):\n",
    "    transformed_dataset, transform_fn = (  # pylint: disable=unused-variable\n",
    "        (raw_data, raw_data_metadata) | tft_beam.AnalyzeAndTransformDataset(\n",
    "            preprocessing_fn))\n",
    "\n",
    "  transformed_data, transformed_metadata = transformed_dataset  # pylint: disable=unused-variable\n",
    "\n",
    "  # Save the transform_fn to the output_dir\n",
    "  _ = # TODO 1: Your code goes here \n",
    "\n",
    "  return transformed_data, transformed_metadata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 968
    },
    "id": "zZPQl0X19ni2",
    "outputId": "f563a107-3977-45ce-b7d2-d6758b86e3c1"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Raw data:\n",
      "[{'s': 'hello', 'x': 1, 'y': 1},\n",
      " {'s': 'world', 'x': 2, 'y': 2},\n",
      " {'s': 'hello', 'x': 3, 'y': 3}]\n",
      "\n",
      "Transformed data:\n",
      "[{'s_integerized': 0,\n",
      "  'x_centered': -1.0,\n",
      "  'x_centered_times_y_normalized': -0.0,\n",
      "  'y_normalized': 0.0},\n",
      " {'s_integerized': 1,\n",
      "  'x_centered': 0.0,\n",
      "  'x_centered_times_y_normalized': 0.0,\n",
      "  'y_normalized': 0.5},\n",
      " {'s_integerized': 0,\n",
      "  'x_centered': 1.0,\n",
      "  'x_centered_times_y_normalized': 1.0,\n",
      "  'y_normalized': 1.0}]\n"
     ]
    }
   ],
   "source": [
    "output_dir = pathlib.Path(tempfile.mkdtemp())\n",
    "\n",
    "transformed_data, transformed_metadata = main(str(output_dir))\n",
    "\n",
    "print('\\nRaw data:\\n{}\\n'.format(pprint.pformat(raw_data)))\n",
    "print('Transformed data:\\n{}'.format(pprint.pformat(transformed_data)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "NO6LyTneNndy"
   },
   "source": [
    "### Is this the right answer?\n",
    "Previously, you used `tf.Transform` to do this:\n",
    "```\n",
    "x_centered = x - tft.mean(x)\n",
    "y_normalized = tft.scale_to_0_1(y)\n",
    "s_integerized = tft.compute_and_apply_vocabulary(s)\n",
    "x_centered_times_y_normalized = (x_centered * y_normalized)\n",
    "```\n",
    "\n",
    "* **x_centered** - With input of `[1, 2, 3]` the mean of x is 2, and you subtract it from x to center your x values at 0. So your result of `[-1.0, 0.0, 1.0]` is correct.\n",
    "* **y_normalized** - You wanted to scale your y values between 0 and 1.  Your input was `[1, 2, 3]` so your result of `[0.0, 0.5, 1.0]` is correct.\n",
    "* **s_integerized** - You wanted to map your strings to indexes in a vocabulary, and there were only 2 words in your vocabulary (\"hello\" and \"world\"). So with input of `[\"hello\", \"world\", \"hello\"]` your result of `[0, 1, 0]` is correct. Since \"hello\" occurs most frequently in this data, it will be the first entry in the vocabulary.\n",
    "* **x_centered_times_y_normalized** - You wanted to create a new feature by crossing `x_centered` and `y_normalized` using multiplication.  Note that this multiplies the results, not the original values, and your new result of `[-0.0, 0.0, 1.0]` is correct."
   ]
  },
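  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can double-check these values by hand. The sketch below recomputes each feature in plain Python, without `tf.Transform`. It assumes that the vocabulary is ordered by descending frequency, as described above, and uses `collections.Counter` as a stand-in for the vocabulary analyzer; tie-breaking between equally frequent strings may differ in the real library.\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "xs = [1.0, 2.0, 3.0]\n",
    "ys = [1.0, 2.0, 3.0]\n",
    "ss = ['hello', 'world', 'hello']\n",
    "\n",
    "# x_centered: subtract the mean of x.\n",
    "mean_x = sum(xs) / len(xs)\n",
    "x_centered = [x - mean_x for x in xs]\n",
    "\n",
    "# y_normalized: min-max scale y into [0, 1].\n",
    "y_min, y_max = min(ys), max(ys)\n",
    "y_normalized = [(y - y_min) / (y_max - y_min) for y in ys]\n",
    "\n",
    "# s_integerized: the most frequent string gets index 0.\n",
    "vocab = [word for word, _ in Counter(ss).most_common()]\n",
    "s_integerized = [vocab.index(s) for s in ss]\n",
    "\n",
    "# Cross the two derived features by multiplication.\n",
    "crossed = [xc * yn for xc, yn in zip(x_centered, y_normalized)]\n",
    "\n",
    "print(x_centered, y_normalized, s_integerized, crossed)\n",
    "```\n",
    "\n",
    "The first crossed value is `-1.0 * 0.0`, which floating-point arithmetic represents as `-0.0`. That is why the transformed output shows `-0.0` rather than `0.0`."
   ]
  },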
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dXw790Sr8Jws"
   },
   "source": [
    "## Use the resulting `transform_fn` directory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "We4Mafrq8id6",
    "outputId": "14594a42-325e-47af-eb50-e399074076ae"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total 8\n",
      "drwxr-xr-x 4 jupyter jupyter 4096 May 30 10:33 transform_fn\n",
      "drwxr-xr-x 2 jupyter jupyter 4096 May 30 10:33 transformed_metadata\n"
     ]
    }
   ],
   "source": [
    "!ls -l {output_dir}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SoaaAXxk_vWP"
   },
   "source": [
    "The `transform_fn/` directory contains a `tf.saved_model` implementing with all the constants tensorflow-transform analysis results built into the graph. \n",
    "\n",
    "It is possible to load this directly with `tf.saved_model.load`, but this not easy to use:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "cz8dqFW6ANJQ",
    "outputId": "ab1f1f09-efc1-4331-d6ea-5d13f04cafa1"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<ConcreteFunction signature_wrapper(*, inputs_2, inputs, inputs_1) at 0x7F0E0AEC1C10>"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load a SavedModel from export_dir\n",
    "loaded = # TODO 2: Your code goes here\n",
    "loaded.signatures['serving_default']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZCugaxMiBosA"
   },
   "source": [
    "A better approach is to load it using `tft.TFTransformOutput`. The `TFTransformOutput.transform_features_layer` method returns a `tft.TransformFeaturesLayer` object that can be used to apply the transformation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "HNd4r2gJ75nx",
    "outputId": "962addb4-8b08-48b3-87f9-2dfae025bd23"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tensorflow_transform.output_wrapper.TransformFeaturesLayer at 0x7f0e0ad2dfd0>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf_transform_output = tft.TFTransformOutput(output_dir)\n",
    "\n",
    "tft_layer = tf_transform_output.transform_features_layer()\n",
    "tft_layer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "se-M1zx49kTY"
   },
   "source": [
    "This `tft.TransformFeaturesLayer` expects a dictionary of batched features. So create a `Dict[str, tf.Tensor]` from the `List[Dict[str, Any]]` in `raw_data`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "id": "2nyE1fVj82Gp"
   },
   "outputs": [],
   "source": [
    "raw_data_batch = {\n",
    "    's': tf.constant([ex['s'] for ex in raw_data]),\n",
    "    'x': tf.constant([ex['x'] for ex in raw_data], dtype=tf.float32),\n",
    "    'y': tf.constant([ex['y'] for ex in raw_data], dtype=tf.float32),\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "016sJ_cD_gVC"
   },
   "source": [
    "You can use the `tft.TransformFeaturesLayer` on it's own:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "fIXJYE0Z9Mrs",
    "outputId": "27dc572a-a671-443b-b712-d5312b0182b2"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'y_normalized': array([0. , 0.5, 1. ], dtype=float32),\n",
       " 's_integerized': array([0, 1, 0]),\n",
       " 'x_centered_times_y_normalized': array([-0.,  0.,  1.], dtype=float32),\n",
       " 'x_centered': array([-1.,  0.,  1.], dtype=float32)}"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "transformed_batch = tft_layer(raw_data_batch)\n",
    "\n",
    "{key: value.numpy() for key, value in transformed_batch.items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FBfO5cp-8pqb"
   },
   "source": [
    "## Export"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "B3SN7D-FzrZ3"
   },
   "source": [
    "A more typical use case would use `tf.Transform` to apply the transformation to the training and evaluation datasets (see the [next tutorial](census.ipynb) for an example). Then, after training, before exporting the model attach the `tft.TransformFeaturesLayer` as the first layer so that you can export it as part of your `tf.saved_model`. For a concrete example, keep reading."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lYV_oy5s9Dn9"
   },
   "source": [
    "### An example training model\n",
    "\n",
    "Below is a model that:\n",
    "\n",
    "1. takes the transformed batch,\n",
    "2. stacks them all together into a simple `(batch, features)` matrix,\n",
    "3. runs them through a few dense layers, and\n",
    "4. produces 10 linear outputs.\n",
    "\n",
    "In a real use case you would apply a one-hot to the `s_integerized` feature.\n",
    "\n",
    "You could train this model on a dataset transformed by `tf.Transform`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "id": "xWiEo1ZUzp4x"
   },
   "outputs": [],
   "source": [
    "class StackDict(tf.keras.layers.Layer):\n",
    "  def call(self, inputs):\n",
    "    values = [\n",
    "        tf.cast(v, tf.float32)\n",
    "        for k,v in sorted(inputs.items(), key=lambda kv: kv[0])]\n",
    "    return tf.stack(values, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "id": "A0QJpoWT1aUD"
   },
   "outputs": [],
   "source": [
    "class TrainedModel(tf.keras.Model):\n",
    "  def __init__(self):\n",
    "    super().__init__(self)\n",
    "    self.concat = StackDict()\n",
    "    self.body = tf.keras.Sequential([\n",
    "        tf.keras.layers.Dense(64, activation='relu'),\n",
    "        tf.keras.layers.Dense(64, activation='relu'),\n",
    "        tf.keras.layers.Dense(10),\n",
    "    ])\n",
    "\n",
    "  def call(self, inputs, training=None):\n",
    "    x = self.concat(inputs)\n",
    "    return self.body(x, training)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "id": "DkMwREIx2fkD"
   },
   "outputs": [],
   "source": [
    "trained_model = TrainedModel()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uBwnbh1Q-TBK"
   },
   "source": [
    "Imagine you trained the model.\n",
    "\n",
    "```\n",
    "trained_model.compile(loss=..., optimizer='adam')\n",
    "trained_model.fit(...)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "notFxUC0AFs6"
   },
   "source": [
    "This model runs on the transformed inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "d2KJ8nGt228O",
    "outputId": "cf2ea5a7-6f26-4844-fbbb-0266a4e32a7d"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TensorShape([3, 10])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trained_model_output = trained_model(transformed_batch)\n",
    "trained_model_output.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fzWs35Ki6M5c"
   },
   "source": [
    "### An example export wrapper\n",
    "\n",
    "Imagine you've trained the above model and want to export it.\n",
    "\n",
    "You'll want to include the transform function in the exported model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "id": "Pe-nbN123qUt"
   },
   "outputs": [],
   "source": [
    "class ExportModel(tf.Module):\n",
    "  def __init__(self, trained_model, input_transform):\n",
    "    super().__init__()\n",
    "    self.trained_model = trained_model\n",
    "    self.input_transform = input_transform\n",
    "\n",
    "  @tf.function\n",
    "  def __call__(self, inputs, training=None):\n",
    "    # Apply the transform to the raw inputs, then run the trained model.\n",
    "    x = self.input_transform(inputs)\n",
    "    return self.trained_model(x, training=training)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "id": "iLUIO-Y87AC0"
   },
   "outputs": [],
   "source": [
    "# Export the model\n",
    "export_model = # TODO 3: Your code goes here"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "sFDYQDgU7ozE"
   },
   "source": [
    "This combined model works on the raw data and produces exactly the same results as calling the trained model directly on the transformed batch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "AqwHTex27ILk",
    "outputId": "220d5904-8697-4d0c-8232-e82a9a42f203"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TensorShape([3, 10])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "export_model_output = export_model(raw_data_batch)\n",
    "export_model_output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "AZQ6_Dfd7xws",
    "outputId": "dafa50a4-153e-4846-935d-e01afc351293"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.reduce_max(abs(export_model_output - trained_model_output)).numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5r-lH_nh8PM-"
   },
   "source": [
    "This `export_model` includes the `tft.TransformFeaturesLayer` and is entirely self-contained. You can save it and restore it in another environment and still get exactly the same result:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "VK17CShl8F7s",
    "outputId": "60ff07cf-ccdb-4224-9f6d-b0dc39842154"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /tmp/tmp65n6_5yltft/assets\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Assets written to: /tmp/tmp65n6_5yltft/assets\n"
     ]
    }
   ],
   "source": [
    "import tempfile\n",
    "model_dir = tempfile.mkdtemp(suffix='tft')\n",
    "\n",
    "tf.saved_model.save(export_model, model_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "RTF-yRnA9yrL",
    "outputId": "b3e51ae6-2cbd-4a6b-ca48-9ba2577be41b",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TensorShape([3, 10])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "reloaded = tf.saved_model.load(model_dir)\n",
    "\n",
    "reloaded_model_output = reloaded(raw_data_batch)\n",
    "reloaded_model_output.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "tFx1I6FQ9_mj",
    "outputId": "542ff9cf-7248-47d0-b4eb-888344f6a0d3"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.reduce_max(abs(export_model_output - reloaded_model_output)).numpy()"
   ]
  },
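  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The save-and-reload round trip above can also be sketched without `tf.Transform`. In the hypothetical `ToyExport` module below, a fixed scaling step stands in for the transform layer and a single weight matrix stands in for the trained model; neither is part of this notebook, and the shapes are chosen only for illustration.\n",
    "\n",
    "```python\n",
    "import tempfile\n",
    "\n",
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "class ToyExport(tf.Module):\n",
    "  # Hypothetical stand-in for ExportModel.\n",
    "  def __init__(self):\n",
    "    super().__init__()\n",
    "    self.scale = tf.constant(0.5)          # plays the role of the transform\n",
    "    self.w = tf.Variable(tf.ones([2, 3]))  # plays the role of trained weights\n",
    "\n",
    "  @tf.function(input_signature=[tf.TensorSpec([None, 2], tf.float32)])\n",
    "  def __call__(self, x):\n",
    "    return tf.matmul(x * self.scale, self.w)\n",
    "\n",
    "\n",
    "toy = ToyExport()\n",
    "batch = tf.constant([[1.0, 2.0], [3.0, 4.0]])\n",
    "before = toy(batch)\n",
    "\n",
    "# Save the module, load it back, and call the restored object on the same batch.\n",
    "toy_dir = tempfile.mkdtemp(suffix='toy')\n",
    "tf.saved_model.save(toy, toy_dir)\n",
    "after = tf.saved_model.load(toy_dir)(batch)\n",
    "\n",
    "# Largest absolute difference between the original and restored outputs.\n",
    "max_diff = tf.reduce_max(tf.abs(before - after)).numpy()\n",
    "print(max_diff)\n",
    "```\n",
    "\n",
    "Because the scaling step is traced into the same `tf.function` as the matrix multiply, the restored object needs only the raw batch, which mirrors how `export_model` carries its transform with it."
   ]
  },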
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [
    "tghWegsjhpkt",
    "cSl9qyTCbBKR",
    "NO6LyTneNndy"
   ],
   "name": "simple.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "environment": {
   "kernel": "python3",
   "name": "tf2-gpu.2-6.m93",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-6:m93"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
